{
 "cells": [
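  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embeddings\n",
    "\n",
    "Predict the contraceptive method class (`contraceptive` column of `../data/cmc.data`) from demographic features with a small Keras network.\n",
    "\n",
    "Reference video: <https://www.youtube.com/watch?v=wSXGlvTR9UM>"
   ]
  },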
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/wSXGlvTR9UM?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/wSXGlvTR9UM?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Activation, Embedding, Merge, Flatten\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Education</th>\n",
       "      <th>H_education</th>\n",
       "      <th>num_child</th>\n",
       "      <th>Religion</th>\n",
       "      <th>Employ</th>\n",
       "      <th>H_occupation</th>\n",
       "      <th>living_standard</th>\n",
       "      <th>Media_exposure</th>\n",
       "      <th>contraceptive</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>43</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>42</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>36</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Age  Education  H_education  num_child  Religion  Employ  H_occupation  \\\n",
       "0   24          2            3          3         1       1             2   \n",
       "1   45          1            3         10         1       1             3   \n",
       "2   43          2            3          7         1       1             3   \n",
       "3   42          3            2          9         1       1             3   \n",
       "4   36          3            3          8         1       1             3   \n",
       "\n",
       "   living_standard  Media_exposure  contraceptive  \n",
       "0                3               0              1  \n",
       "1                4               0              1  \n",
       "2                4               0              1  \n",
       "3                3               0              1  \n",
       "4                2               0              1  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('../data/cmc.data',header=None,names=['Age','Education','H_education',\n",
    "                                                     'num_child','Religion', 'Employ',\n",
    "                                                     'H_occupation','living_standard',\n",
    "                                                     'Media_exposure','contraceptive'])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Age                False\n",
       "Education          False\n",
       "H_education        False\n",
       "num_child          False\n",
       "Religion           False\n",
       "Employ             False\n",
       "H_occupation       False\n",
       "living_standard    False\n",
       "Media_exposure     False\n",
       "contraceptive      False\n",
       "dtype: bool"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.isnull().any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x11a6c2ba8>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE4RJREFUeJzt3X+sX3ddx/Hn23bA7MV2MLw2bbVLXEgmFVhvthKIuZdFUzayLnHgyNy6ZaaJDsQw4wqJGvwRyx+ATA3aMLIO0csywdVu0yxdr8gfG7Y418FAyihZb8oqW1e4bGqqb//4fsDL5bbf8/3e8+33ez8+H8nNPedzPud8P+/zuX3d0/P9cSMzkSTV60eGPQBJ0mAZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVa5R0EfEmoi4NyK+HBFPRsQbIuIVEfFQRHy1fL+g9I2IuCMijkTE4xFx6WBLkCSdTTR5Z2xE7AH+KTM/FhEvAX4UeB/wXGbuioidwAWZeXtEXAm8C7gSuBz4SGZefrbjX3jhhblx48a+Cvjud7/LqlWr+tp31FjL6KmlDrCWUbWUWg4dOvStzHxV146ZedYvYDXwdcovhXntXwHWluW1wFfK8l8A71is35m+Nm/enP06cOBA3/uOGmsZPbXUkWkto2optQAHs0uGZ2b3K/qIeB2wG/gS8FrgEPBuYDYz15Q+AZzMzDURsQ/YlZmfK9v2A7dn5sEFx90B7AAYHx/fPD093fWX0mLm5uYYGxvra99RYy2jp5Y6wFpG1VJqmZqaOpSZE107dvtNAEwAp4HLy/pHgN8Hnl/Q72T5vg9407z2/cDE2R7DK/oOaxk9tdSRaS2j6lxc0Td5MvYYcCwzHy3r9wKXAs9ExFqA8v1E2T4LbJi3//rSJkkagq5Bn5nfBJ6OiFeXpivo3MbZC2wvbduB+8ryXuDG8uqbLcCpzDze7rAlSU2tbNjvXcAnyytungJupvNL4p6IuAX4BvD20vcBOq+4OQK8UPpKkoakUdBn5mN07tUvdMUifRO4dYnjkiS1xHfGSlLlDHpJqpxBL0mVa/pkrCRVa+PO+4f22HdtHfxHOXhFL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RByNiMMR8VhEHCxtr4iIhyLiq+X7BaU9IuKOiDgSEY9HxKWDLECSdHa9XNFPZebrMnOirO8E9mfmxcD+sg7wFuDi8rUD+Ghbg5Uk9W4pt262AXvK8h7gmnntd2fHI8CaiFi7hMeRJC1BZGb3ThFfB04CCfxFZu6OiOczc03ZHsDJzFwTEfuAXZn5ubJtP3B7Zh5ccMwddK74GR8f3zw9Pd1XAXNzc4yNjfW176ixltFTSx1gLWdzePZUa8fq1UWrV/Rdy9TU1KF5d1nOaGXD470pM2cj4seBhyLiy/M3ZmZGRPffGD+4z25gN8DExEROTk72svv3zczM0O++o8ZaRk8tdYC1nM1NO+9v7Vi9umvrqoHPS6NbN5k5W76fAD4DXAY8871bMuX7idJ9Ftgwb/f1pU2SNARdgz4iVkXEy7+3DPwC8ASwF9heum0H7ivLe4Eby6tvtgCnMvN46yOXJDXS5NbNOPCZzm14VgJ/lZl/HxH/DNwTEbcA3wDeXvo/AFwJHAFeAG5ufdSSpMa6Bn1mPgW8dpH2Z4ErFmlP4NZWRidJWjLfGStJlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6Saqc
QS9JlTPoJalyBr0kVW5l044RsQI4CMxm5lsj4iJgGnglcAi4ITP/KyJeCtwNbAaeBX4pM4+2PnKpcht33t/q8W7bdJqbGhzz6K6rWn1cDV8vV/TvBp6ct/4B4MOZ+dPASeCW0n4LcLK0f7j0kyQNSaOgj4j1wFXAx8p6AG8G7i1d9gDXlOVtZZ2y/YrSX5I0BJGZ3TtF3Av8EfBy4DeBm4BHylU7EbEBeDAzXxMRTwBbM/NY2fY14PLM/NaCY+4AdgCMj49vnp6e7quAubk5xsbG+tp31FjL6BlmHYdnT7V6vPHz4ZkXu/fbtG51q487CG3PS9vnuhcXrV7Rdy1TU1OHMnOiW7+u9+gj4q3Aicw8FBGTfY1mEZm5G9gNMDExkZOT/R16ZmaGfvcdNdYyeoZZR5P76b24bdNpPni4+9NyR6+fbPVxB6HteWn7XPfirq2rBv4z1uTJ2DcCV0fElcDLgB8DPgKsiYiVmXkaWA/Mlv6zwAbgWESsBFbTeVJWkjQEXe/RZ+Z7M3N9Zm4ErgMezszrgQPAtaXbduC+sry3rFO2P5xN7g9JkgZiKa+jvx14T0QcofMSyztL+53AK0v7e4CdSxuiJGkpGr+OHiAzZ4CZsvwUcNkiff4DeFsLY5MktcB3xkpS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqlxPf0pQ2rjz/kb9btt0mpsa9m3i6K6rWjuW9P+NV/SSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWua9BHxMsi4vMR8a8R8cWIeH9pvygiHo2IIxHxqYh4SWl/aVk/UrZvHGwJkqSzaXJF/5/AmzPztcDrgK0RsQX4APDhzPxp4CRwS+l/C3CytH+49JMkDUnXoM+OubJ6XvlK4M3AvaV9D3BNWd5W1inbr4iIaG3EkqSeNLpHHxErIuIx4ATwEPA14PnMPF26HAPWleV1wNMAZfsp4JVtDlqS1FxkZvPOEWuAzwC/DdxVbs8QERuABzPzNRHxBLA1M4+VbV8DLs/Mby041g5gB8D4+Pjm6enpvgqYm5tjbGysr31HzXKo5fDsqUb9xs+HZ15s73E3rVvd3sF6MMw5aXqum2o6J8M6171oe17aPte9uGj1ir5rmZqaOpSZE9369fShZpn5fEQcAN4ArImIleWqfT0wW7rNAhuAYxGxElgNPLvIsXYDuwEmJiZycnKyl6F838zMDP3uO2qWQy1NP6jstk2n+eDh9j4z7+j1k60dqxfDnJM2PxQOms/JsM51L9qel7bPdS/u2rpq4D9jTV5186pyJU9EnA/8PPAkcAC4tnTbDtxXlveWdcr2h7OX/zZIklrV5JJrLbAnIlbQ+cVwT2bui4gvAdMR8QfAvwB3lv53Ap+IiCPAc8B1Axi3JKmhrkGfmY8Dr1+k/SngskXa/wN4WyujkyQtme+MlaTKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIq1zXoI2JDRByIiC9FxBcj4t2l/RUR8VBEfLV8v6C0R0TcERFHIuLxiLh00EVIks6syRX9aeC2zLwE2ALcGhGXADuB/Zl5MbC/rAO8Bbi4fO0APtr6qCVJjXUN+sw8nplfKMvfAZ4E1gHbgD2l2x7gmrK8Dbg7Ox4B1kTE2tZHLklqpKd79BGxEXg98CgwnpnHy6ZvAuNleR3w9LzdjpU2SdIQRGY26xgxBvwj8IeZ+emIeD4z18zbfjIzL4iIfcCuzPxcad8P3J6ZBxccbwedWzuMj49vnp6e7quAubk5xsbG+tp31CyHWg7PnmrUb/x8
eObF9h5307rV7R2sB8Ock6bnuqmmczKsc92Ltuel7XPdi4tWr+i7lqmpqUOZOdGt38omB4uI84C/AT6ZmZ8uzc9ExNrMPF5uzZwo7bPAhnm7ry9tPyAzdwO7ASYmJnJycrLJUH7IzMwM/e47apZDLTftvL9Rv9s2neaDhxv9eDVy9PrJ1o7Vi2HOSdNz3VTTORnWue5F2/PS9rnuxV1bVw38Z6zJq24CuBN4MjM/NG/TXmB7Wd4O3Dev/cby6pstwKl5t3gkSedYk0uuNwI3AIcj4rHS9j5gF3BPRNwCfAN4e9n2AHAlcAR4Abi51RFLknrSNejLvfY4w+YrFumfwK1LHJckqSW+M1aSKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWvvA8OH5PDsqaF9lvTRXVcN5XElqRde0UtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekynUN+oj4eESciIgn5rW9IiIeioivlu8XlPaIiDsi4khEPB4Rlw5y8JKk7ppc0d8FbF3QthPYn5kXA/vLOsBbgIvL1w7go+0MU5LUr65Bn5mfBZ5b0LwN2FOW9wDXzGu/OzseAdZExNq2BitJ6l2/9+jHM/N4Wf4mMF6W1wFPz+t3rLRJkoYkMrN7p4iNwL7MfE1Zfz4z18zbfjIzL4iIfcCuzPxcad8P3J6ZBxc55g46t3cYHx/fPD093VcBJ547xTMv9rXrkm1at7rV483NzTE2NtbqMdt2ePZUo37j59PqvLR9rpsa5pw0PddNNZ2TYZ3rXrQ9L22f615ctHpF37VMTU0dysyJbv36/Zuxz0TE2sw8Xm7NnCjts8CGef3Wl7Yfkpm7gd0AExMTOTk52ddA/uST9/HBw8P507dHr59s9XgzMzP0ex7OlaZ/n/e2TadbnZe2z3VTw5yTtv8WctM5Gda57kXb8zKsvzsNcNfWVQP/Gev31s1eYHtZ3g7cN6/9xvLqmy3AqXm3eCRJQ9D113tE/DUwCVwYEceA3wV2AfdExC3AN4C3l+4PAFcCR4AXgJsHMGZJUg+6Bn1mvuMMm65YpG8Cty51UJKk9vjOWEmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcgMJ+ojYGhFfiYgjEbFzEI8hSWqm9aCPiBXAnwFvAS4B3hERl7T9OJKkZgZxRX8ZcCQzn8rM/wKmgW0DeBxJUgODCPp1wNPz1o+VNknSEERmtnvAiGuBrZn5K2X9BuDyzHzngn47gB1l9dXAV/p8yAuBb/W576ixltFTSx1gLaNqKbX8VGa+qlunlX0e/GxmgQ3z1teXth+QmbuB3Ut9sIg4mJkTSz3OKLCW0VNLHWAto+pc1DKIWzf/DFwcERdFxEuA64C9A3gcSVIDrV/RZ+bpiHgn8A/ACuDjmfnFth9HktTMIG7dkJkPAA8M4tiLWPLtnxFiLaOnljrAWkbVwGtp/clYSdJo8SMQJKlyyyLoI+LjEXEiIp44w/aIiDvKRy48HhGXnusxNtWglsmIOBURj5Wv3znXY2wqIjZExIGI+FJEfDEi3r1In5Gfm4Z1LIt5iYiXRcTnI+JfSy3vX6TPSyPiU2VOHo2Ijed+pN01rOWmiPj3efPyK8MYaxMRsSIi/iUi9i2ybbBzkpkj/wX8HHAp8MQZtl8JPAgEsAV4dNhjXkItk8C+YY+zYS1rgUvL8suBfwMuWW5z07COZTEv5TyPleXzgEeBLQv6/Brw52X5OuBTwx73Emq5CfjTYY+1YT3vAf5qsZ+jQc/Jsriiz8zPAs+dpcs2
4O7seARYExFrz83oetOglmUjM49n5hfK8neAJ/nhd0GP/Nw0rGNZKOd5rqyeV74WPhG3DdhTlu8FroiIOEdDbKxhLctCRKwHrgI+doYuA52TZRH0DdT2sQtvKP9dfTAifmbYg2mi/Ffz9XSuuuZbVnNzljpgmcxLuUXwGHACeCgzzzgnmXkaOAW88tyOspkGtQD8YrkteG9EbFhk+yj4Y+C3gP85w/aBzkktQV+TL9B5W/NrgT8B/nbI4+kqIsaAvwF+IzO/Pezx9KtLHctmXjLzvzPzdXTelX5ZRLxm2GPqV4Na/g7YmJk/CzzE/10Vj4yIeCtwIjMPDWsMtQR9o49dWA4y89vf++9qdt6PcF5EXDjkYZ1RRJxHJxw/mZmfXqTLspibbnUst3kByMzngQPA1gWbvj8nEbESWA08e25H15sz1ZKZz2bmf5bVjwGbz/XYGngjcHVEHKXzab5vjoi/XNBnoHNSS9DvBW4sr/DYApzKzOPDHlQ/IuInvndvLiIuozNHI/mPsIzzTuDJzPzQGbqN/Nw0qWO5zEtEvCoi1pTl84GfB768oNteYHtZvhZ4OMuzgKOkSS0Lnu+5ms7zKyMlM9+bmeszcyOdJ1ofzsxfXtBtoHMykHfGti0i/prOqx4ujIhjwO/SeWKGzPxzOu/CvRI4ArwA3DyckXbXoJZrgV+NiNPAi8B1o/iPsHgjcANwuNxHBXgf8JOwrOamSR3LZV7WAnui8weAfgS4JzP3RcTvAQczcy+dX2qfiIgjdF4YcN3whntWTWr59Yi4GjhNp5abhjbaHp3LOfGdsZJUuVpu3UiSzsCgl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcv8LAfcigdNBX0MAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11a6a25f8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df.Education.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1473, 10)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x11a7ca630>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAFJxJREFUeJzt3X+MndV95/H3tzYkXU/WAyGdtWxv7VWtXZHQJHhESRNVM0G7NU67ZqUUEaHGIEuWWrZK1f2BW6mt2u5K5A82TdBuulaJMJWbAdGwtgzJFjme7WYjnOKUYH4kmwlxikfEVrCZZgJt5fTbP+4hXIzH89w795eP3i/pap7nnHPv/T4Ph88899y515GZSJLq9WPDLkCS1F8GvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyq4ddAMBVV12VmzZt6uq+P/jBD1izZk1vC+oB6+qMdXVuVGuzrs6spK5jx459LzPfsezAzBz6bevWrdmtI0eOdH3ffrKuzlhX50a1NuvqzErqAp7IBhnr0o0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFVuJL4CYSWOzy9w255HhvLcJ+760FCeV5I64RW9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RIxHxEMR8fWIeC4i3hcRV0bEYxHxzfLzijI2IuJTETEXEU9FxLX9PQRJ0sU0vaL/JPCFzPxXwLuB54A9wOHM3AIcLvsANwJbym038OmeVixJ6siyQR8Ra4GfA+4FyMy/z8yXgR3AvjJsH3BT2d4B3F/+AZTHgfGIWNfzyiVJjUTrX6O6yICI9wB7gWdpXc0fAz4GzGfmeBkTwNnMHI+IQ8Bdmfml0ncYuDMznzjvcXfTuuJnYmJi68zMTFcHcPrMAqde7equK3bN+rVL9i0uLjI2NjbAapqxrs6Mal0wurVZV2dWUtf09PSxzJxcblyTT8auBq4Ffi0zj0bEJ3l9mQaAzMyIuPhvjPNk5l5av0CYnJzMqampTu7+I/fsP8Ddx4fzAd8Tt04t2Tc7O0u3x9RP1tWZUa0LRrc26+rMIOpqskZ/EjiZmUfL/kO0gv/Ua0sy5efp0j8PbGy7/4bSJkkagmWDPjO/C7wQEf+yNN1AaxnnILCztO0EDpTtg8BHy1/fXA8sZOaLvS1bktRU0zWPXwP2R8TlwPPA7bR+STwYEbuA7wA3l7GPAtuBOeCVMlaSNCSNgj4znwQutOB/wwXGJnDHCuuSJPWIn4yVpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVbnWTQRFxAvg+8EPgXGZORsSVwAPAJuAEcHNmno2IAD4JbAdeAW7LzK/2vnRJ6o1Nex4Z2nPft21N35+jkyv66cx8T2ZOlv09wOHM3AIcLvsANwJbym038OleFStJ6txKlm52APvK9j7gprb2+7PlcWA8Itat4HkkSSsQmbn8oIhvA2eBBP5nZu6NiJczc7z0B3A2M8cj4hBwV2Z+qfQdBu7MzCfOe8zdtK74mZiY2DozM9PVAZw+s8CpV7u664pds37tkn2Li4uMjY0NsJpmrKszo1oXjG5tl2Jdx+cXBlzN6zavXdX1+Zqenj7WtsqypEZr9MAHMnM+In4CeCwivt7emZkZEcv/xnjjffYCewEmJydzamqqk7v/yD37D3D38aaH0Vsnbp1asm92dpZuj6mfrKszo1oXjG5tl2Jdtw15jb7f56vR0k1mzpefp4GHgeuAU68tyZSfp8vweWBj2903lDZJ0hAsG/QRsSYi3vbaNvBvgKeBg8DOMmwncKBsHwQ+Gi3XAwuZ+WLPK5ck
NdJkzWMCeLi1DM9q4E8z8wsR8ZfAgxGxC/gOcHMZ/yitP62co/Xnlbf3vGpJUmPLBn1mPg+8+wLtLwE3XKA9gTt6Up0kacX8ZKwkVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWsc9BGxKiL+KiIOlf3NEXE0IuYi4oGIuLy0v6Xsz5X+Tf0pXZLURCdX9B8Dnmvb/zjwicz8KeAssKu07wLOlvZPlHGSpCFpFPQRsQH4EPDHZT+ADwIPlSH7gJvK9o6yT+m/oYyXJA1B0yv6PwT+M/APZf/twMuZea7snwTWl+31wAsApX+hjJckDUFk5sUHRPwCsD0zfzUipoD/CNwGPF6WZ4iIjcDnM/NdEfE0sC0zT5a+bwE/k5nfO+9xdwO7ASYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ymGevqzKjWBaNb26VY1/H5hQFX87rNa1d1fb6mp6ePZebkcuNWN3is9wP/NiK2A28F/inwSWA8IlaXq/YNwHwZPw9sBE5GxGpgLfDS+Q+amXuBvQCTk5M5NTXVoJQ3u2f/Ae4+3uQweu/ErVNL9s3OztLtMfWTdXVmVOuC0a3tUqzrtj2PDLaYNvdtW9P387Xs0k1m/mZmbsjMTcAtwBcz81bgCPDhMmwncKBsHyz7lP4v5nIvGyRJfbOSv6O/E/iNiJijtQZ/b2m/F3h7af8NYM/KSpQkrURHax6ZOQvMlu3ngesuMOZvgV/qQW2SpB7wk7GSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUueH8q9rSJeL4/MLQ/uHoE3d9aCjPq/p4RS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqt2zQR8RbI+IrEfG1iHgmIn6vtG+OiKMRMRcRD0TE5aX9LWV/rvRv6u8hSJIupskV/d8BH8zMdwPvAbZFxPXAx4FPZOZPAWeBXWX8LuBsaf9EGSdJGpJlgz5bFsvuZeWWwAeBh0r7PuCmsr2j7FP6b4iI6FnFkqSONFqjj4hVEfEkcBp4DPgW8HJmnitDTgLry/Z64AWA0r8AvL2XRUuSmovMbD44Yhx4GPht4L6yPENEbAQ+n5nvioingW2ZebL0fQv4mcz83nmPtRvYDTAxMbF1ZmamqwM4fWaBU692ddcVu2b92iX7FhcXGRsbG2A1zVhXZ0Z1fsHonrNLsa7j8wsDruZ1m9eu6vp8TU9PH8vMyeXGdfRdN5n5ckQcAd4HjEfE6nLVvgGYL8PmgY3AyYhYDawFXrrAY+0F9gJMTk7m1NRUJ6X8yD37D3D38eF8Zc+JW6eW7JudnaXbY+on6+rMqM4vGN1zdinWNazvMwK4b9uavp+vJn91845yJU9E/Djwr4HngCPAh8uwncCBsn2w7FP6v5idvGyQJPVUk0uVdcC+iFhF6xfDg5l5KCKeBWYi4r8AfwXcW8bfC/xJRMwBZ4Bb+lC3JKmhZYM+M58C3nuB9ueB6y7Q/rfAL/WkOknSivnJWEmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuWWDfqI2BgRRyLi2Yh4JiI+VtqvjIjHIuKb5ecVpT0i4lMRMRcRT0XEtf0+CEnS0ppc0Z8D/kNmXg1cD9wREVcDe4DDmbkFOFz2AW4EtpTbbuDTPa9aktTYskGfmS9m5lfL9veB54D1wA5gXxm2D7ipbO8A7s+Wx4HxiFjX88olSY10
tEYfEZuA9wJHgYnMfLF0fReYKNvrgRfa7naytEmShiAys9nAiDHg/wD/NTM/FxEvZ+Z4W//ZzLwiIg4Bd2Xml0r7YeDOzHzivMfbTWtph4mJia0zMzNdHcDpMwucerWru67YNevXLtm3uLjI2NjYAKtpxro6M6rzC0b3nF2KdR2fXxhwNa/bvHZV1+drenr6WGZOLjdudZMHi4jLgD8D9mfm50rzqYhYl5kvlqWZ06V9HtjYdvcNpe0NMnMvsBdgcnIyp6ammpTyJvfsP8DdxxsdRs+duHVqyb7Z2Vm6PaZ+sq7OjOr8gtE9Z5diXbfteWSwxbS5b9uavp+vJn91E8C9wHOZ+d/aug4CO8v2TuBAW/tHy1/fXA8stC3xSJIGrMmlyvuBXwaOR8STpe23gLuAByNiF/Ad4ObS9yiwHZgDXgFu72nFkqSOLBv0Za09lui+4QLjE7hjhXVJknrET8ZKUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqt2zQR8RnIuJ0RDzd1nZlRDwWEd8sP68o7RERn4qIuYh4KiKu7WfxkqTlNbmivw/Ydl7bHuBwZm4BDpd9gBuBLeW2G/h0b8qUJHVr2aDPzL8AzpzXvAPYV7b3ATe1td+fLY8D4xGxrlfFSpI61+0a/URmvli2vwtMlO31wAtt406WNknSkERmLj8oYhNwKDPfVfZfzszxtv6zmXlFRBwC7srML5X2w8CdmfnEBR5zN63lHSYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ymGevqzKjOLxjdc3Yp1nV8fmHA1bxu89pVXZ+v6enpY5k5udy41V09OpyKiHWZ+WJZmjld2ueBjW3jNpS2N8nMvcBegMnJyZyamuqqkHv2H+Du490exsqcuHVqyb7Z2Vm6PaZ+sq7OjOr8gtE9Z5diXbfteWSwxbS5b9uavp+vbpduDgI7y/ZO4EBb+0fLX99cDyy0LfFIkoZg2UuViPgsMAVcFREngd8F7gIejIhdwHeAm8vwR4HtwBzwCnB7H2qWJHVg2aDPzI8s0XXDBcYmcMdKi5Ik9Y6fjJWkyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWuL0EfEdsi4hsRMRcRe/rxHJKkZnoe9BGxCvjvwI3A1cBHIuLqXj+PJKmZflzRXwfMZebzmfn3wAywow/PI0lqoB9Bvx54oW3/ZGmTJA3B6mE9cUTsBnaX3cWI+EaXD3UV8L3eVNWZ+PhFu4dW1zKsqzOjOr/Ac9apkaxr+uMrqusnmwzqR9DPAxvb9jeUtjfIzL3A3pU+WUQ8kZmTK32cXrOuzlhX50a1NuvqzCDq6sfSzV8CWyJic0RcDtwCHOzD80iSGuj5FX1mnouIfw/8b2AV8JnMfKbXzyNJaqYva/SZ+SjwaD8e+wJWvPzTJ9bVGevq3KjWZl2d6XtdkZn9fg5J0hD5FQiSVLmRDfqI+ExEnI6Ip5foj4j4VPmahaci4tq2vp0R8c1y2zngum4t9RyPiC9HxLvb+k6U9icj4okB1zUVEQvluZ+MiN9p6+vbV1Y0qOs/tdX0dET8MCKuLH39PF8bI+JIRDwbEc9ExMcuMGbgc6xhXQOfYw3rGvgca1jXwOdYRLw1Ir4SEV8rdf3eBca8JSIeKOfkaERsauv7zdL+jYj4+RUXlJkjeQN+DrgWeHqJ/u3A54EArgeOlvYrgefLzyvK9hUDrOtnX3s+Wl8DcbSt7wRw1ZDO1xRw6ALtq4BvAf8CuBz4GnD1oOo6
b+wvAl8c0PlaB1xbtt8G/P/zj3sYc6xhXQOfYw3rGvgca1LXMOZYmTNjZfsy4Chw/XljfhX4o7J9C/BA2b66nKO3AJvLuVu1knpG9oo+M/8COHORITuA+7PlcWA8ItYBPw88lplnMvMs8BiwbVB1ZeaXy/MCPE7rcwR91+B8LaWvX1nRYV0fAT7bq+e+mMx8MTO/Wra/DzzHmz/BPfA51qSuYcyxhudrKX2bY13UNZA5VubMYtm9rNzOf0N0B7CvbD8E3BARUdpnMvPvMvPbwBytc9i1kQ36Bpb6qoVR+gqGXbSuCF+TwJ9HxLFofTJ40N5XXkp+PiLeWdpG4nxFxD+hFZZ/1tY8kPNVXjK/l9ZVV7uhzrGL1NVu4HNsmbqGNseWO1+DnmMRsSoingRO07owWHJ+ZeY5YAF4O304X0P7CoTaRcQ0rf8JP9DW/IHMnI+InwAei4ivlyveQfgq8JOZuRgR24H/BWwZ0HM38YvA/8vM9qv/vp+viBij9T/+r2fm3/TysVeiSV3DmGPL1DW0Odbwv+NA51hm/hB4T0SMAw9HxLsy84LvVfXbpXxFv9RXLTT6CoZ+ioifBv4Y2JGZL73Wnpnz5edp4GFW+HKsE5n5N6+9lMzW5xwui4irGIHzVdzCeS+p+32+IuIyWuGwPzM/d4EhQ5ljDeoayhxbrq5hzbEm56sY+Bwrj/0ycIQ3L+/96LxExGpgLfAS/ThfvXwDotc3YBNLv7n4Id74RtlXSvuVwLdpvUl2Rdm+coB1/XNaa2o/e177GuBtbdtfBrYNsK5/xuufm7gO+Oty7lbTejNxM6+/UfbOQdVV+tfSWsdfM6jzVY79fuAPLzJm4HOsYV0Dn2MN6xr4HGtS1zDmGPAOYLxs/zjwf4FfOG/MHbzxzdgHy/Y7eeObsc+zwjdjR3bpJiI+S+td/Ksi4iTwu7Te0CAz/4jWJ2+305rwrwC3l74zEfEHtL5zB+D3840v1fpd1+/QWmf7H633VTiXrS8smqD18g1aE/9PM/MLA6zrw8CvRMQ54FXglmzNqr5+ZUWDugD+HfDnmfmDtrv29XwB7wd+GThe1lEBfotWiA5zjjWpaxhzrEldw5hjTeqCwc+xdcC+aP1DTD9GK8QPRcTvA09k5kHgXuBPImKO1i+hW0rNz0TEg8CzwDngjmwtA3XNT8ZKUuUu5TV6SVIDBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZX7RwDj7+S1+Ez4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11a7d43c8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df.contraceptive.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Age                int64\n",
       "Education          int64\n",
       "H_education        int64\n",
       "num_child          int64\n",
       "Religion           int64\n",
       "Employ             int64\n",
       "H_occupation       int64\n",
       "living_standard    int64\n",
       "Media_exposure     int64\n",
       "contraceptive      int64\n",
       "dtype: object"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def one_hot_encoding(idx, num_classes=None):\n",
    "    \"\"\"Convert a 1-D sequence of non-negative integer class indices to a one-hot matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    idx : array-like of int\n",
    "        Class index for each sample; every value must be >= 0.\n",
    "    num_classes : int, optional\n",
    "        Number of output columns. Defaults to ``max(idx) + 1``; pass it explicitly\n",
    "        when the input may not contain the largest class (e.g. a small split).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Float array of shape ``(len(idx), num_classes)`` with a single 1 per row.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If an index is negative (numpy indexing would otherwise silently wrap it\n",
    "        around from the last column) or is not smaller than ``num_classes``.\n",
    "    \"\"\"\n",
    "    idx = np.asarray(idx)\n",
    "    if idx.size == 0:\n",
    "        # max() of an empty input is undefined; return an empty matrix instead of raising.\n",
    "        return np.zeros((0, 0 if num_classes is None else num_classes))\n",
    "    if idx.min() < 0:\n",
    "        raise ValueError('one_hot_encoding: indices must be non-negative')\n",
    "    if num_classes is None:\n",
    "        num_classes = int(idx.max()) + 1\n",
    "    elif idx.max() >= num_classes:\n",
    "        raise ValueError('one_hot_encoding: index out of range for num_classes')\n",
    "    y = np.zeros((len(idx), num_classes))\n",
    "    # Fancy indexing: row i gets a 1 in column idx[i].\n",
    "    y[np.arange(len(idx)), idx] = 1\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardize Age and num_child in place (StandardScaler: zero mean, unit variance).\n",
    "# NOTE(review): the scaler is fit on the full dataset, i.e. *before* the train/test\n",
    "# split in the next cell, so test-set statistics leak into the scaling. Fitting on the\n",
    "# training split only (then `transform` on the test split) would avoid this.\n",
    "scaler = StandardScaler()\n",
    "df[['Age','num_child']] = scaler.fit_transform(df[['Age','num_child']]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs used as-is, not one-hot encoded (Age and num_child were standardized above).\n",
    "x = df[['Age','num_child','Employ','Media_exposure']].values\n",
    "# Target: contraceptive codes shifted by 1 to 0-based indices, then one-hot encoded.\n",
    "y = one_hot_encoding(df.contraceptive.values-1)\n",
    "\n",
    "# Highest category code per column (= number of categories if codes run 1..max).\n",
    "# NOTE(review): presumably used later as Embedding input sizes - verify.\n",
    "liv_cats = df.living_standard.max()\n",
    "edu_cats = df.Education.max()\n",
    "\n",
    "# 0-based category indices (codes shifted by 1) and their one-hot encodings.\n",
    "liv = df.living_standard.values - 1\n",
    "liv_one_hot = one_hot_encoding(liv)\n",
    "edu = df.Education.values - 1\n",
    "edu_one_hot = one_hot_encoding(edu)\n",
    "\n",
    "# Split all arrays with the same row shuffle: 90% train / 10% test, reproducible via random_state.\n",
    "train_x, test_x, train_liv, \\\n",
    "test_liv, train_edu, test_edu, train_y, test_y = train_test_split(x,liv_one_hot,edu_one_hot,y,test_size=0.1, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense-model input: the 4 as-is features + one-hot Education + one-hot living_standard\n",
    "# (4 + 4 + 4 = 12 columns).\n",
    "# NOTE(review): this rebinds train_x/test_x, so re-running this cell without re-running\n",
    "# the split above appends the one-hot columns a second time.\n",
    "train_x = np.hstack([train_x, train_edu, train_liv])\n",
    "test_x = np.hstack([test_x, test_edu, test_liv])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 12)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 4)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_edu.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 4)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_liv.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 12)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:2: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(input_dim=12, units=12)`\n",
      "  \n",
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n",
      "  after removing the cwd from sys.path.\n",
      "/usr/local/lib/python3.6/site-packages/keras/models.py:942: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      " - 0s - loss: 1.1417 - acc: 0.3155\n",
      "Epoch 2/100\n",
      " - 0s - loss: 1.0537 - acc: 0.4060\n",
      "Epoch 3/100\n",
      " - 0s - loss: 1.0294 - acc: 0.4408\n",
      "Epoch 4/100\n",
      " - 0s - loss: 1.0140 - acc: 0.4679\n",
      "Epoch 5/100\n",
      " - 0s - loss: 1.0023 - acc: 0.4815\n",
      "Epoch 6/100\n",
      " - 0s - loss: 0.9930 - acc: 0.4921\n",
      "Epoch 7/100\n",
      " - 0s - loss: 0.9852 - acc: 0.5019\n",
      "Epoch 8/100\n",
      " - 0s - loss: 0.9786 - acc: 0.5102\n",
      "Epoch 9/100\n",
      " - 0s - loss: 0.9733 - acc: 0.5177\n",
      "Epoch 10/100\n",
      " - 0s - loss: 0.9684 - acc: 0.5185\n",
      "Epoch 11/100\n",
      " - 0s - loss: 0.9642 - acc: 0.5200\n",
      "Epoch 12/100\n",
      " - 0s - loss: 0.9605 - acc: 0.5260\n",
      "Epoch 13/100\n",
      " - 0s - loss: 0.9573 - acc: 0.5411\n",
      "Epoch 14/100\n",
      " - 0s - loss: 0.9543 - acc: 0.5426\n",
      "Epoch 15/100\n",
      " - 0s - loss: 0.9518 - acc: 0.5472\n",
      "Epoch 16/100\n",
      " - 0s - loss: 0.9494 - acc: 0.5487\n",
      "Epoch 17/100\n",
      " - 0s - loss: 0.9471 - acc: 0.5487\n",
      "Epoch 18/100\n",
      " - 0s - loss: 0.9451 - acc: 0.5494\n",
      "Epoch 19/100\n",
      " - 0s - loss: 0.9433 - acc: 0.5532\n",
      "Epoch 20/100\n",
      " - 0s - loss: 0.9415 - acc: 0.5540\n",
      "Epoch 21/100\n",
      " - 0s - loss: 0.9399 - acc: 0.5502\n",
      "Epoch 22/100\n",
      " - 0s - loss: 0.9385 - acc: 0.5562\n",
      "Epoch 23/100\n",
      " - 0s - loss: 0.9372 - acc: 0.5532\n",
      "Epoch 24/100\n",
      " - 0s - loss: 0.9359 - acc: 0.5509\n",
      "Epoch 25/100\n",
      " - 0s - loss: 0.9347 - acc: 0.5570\n",
      "Epoch 26/100\n",
      " - 0s - loss: 0.9335 - acc: 0.5540\n",
      "Epoch 27/100\n",
      " - 0s - loss: 0.9325 - acc: 0.5562\n",
      "Epoch 28/100\n",
      " - 0s - loss: 0.9315 - acc: 0.5630\n",
      "Epoch 29/100\n",
      " - 0s - loss: 0.9305 - acc: 0.5660\n",
      "Epoch 30/100\n",
      " - 0s - loss: 0.9296 - acc: 0.5683\n",
      "Epoch 31/100\n",
      " - 0s - loss: 0.9288 - acc: 0.5638\n",
      "Epoch 32/100\n",
      " - 0s - loss: 0.9280 - acc: 0.5668\n",
      "Epoch 33/100\n",
      " - 0s - loss: 0.9271 - acc: 0.5660\n",
      "Epoch 34/100\n",
      " - 0s - loss: 0.9264 - acc: 0.5698\n",
      "Epoch 35/100\n",
      " - 0s - loss: 0.9256 - acc: 0.5675\n",
      "Epoch 36/100\n",
      " - 0s - loss: 0.9250 - acc: 0.5691\n",
      "Epoch 37/100\n",
      " - 0s - loss: 0.9242 - acc: 0.5698\n",
      "Epoch 38/100\n",
      " - 0s - loss: 0.9236 - acc: 0.5743\n",
      "Epoch 39/100\n",
      " - 0s - loss: 0.9230 - acc: 0.5736\n",
      "Epoch 40/100\n",
      " - 0s - loss: 0.9223 - acc: 0.5751\n",
      "Epoch 41/100\n",
      " - 0s - loss: 0.9218 - acc: 0.5758\n",
      "Epoch 42/100\n",
      " - 0s - loss: 0.9211 - acc: 0.5751\n",
      "Epoch 43/100\n",
      " - 0s - loss: 0.9206 - acc: 0.5736\n",
      "Epoch 44/100\n",
      " - 0s - loss: 0.9199 - acc: 0.5751\n",
      "Epoch 45/100\n",
      " - 0s - loss: 0.9194 - acc: 0.5758\n",
      "Epoch 46/100\n",
      " - 0s - loss: 0.9190 - acc: 0.5804\n",
      "Epoch 47/100\n",
      " - 0s - loss: 0.9184 - acc: 0.5766\n",
      "Epoch 48/100\n",
      " - 0s - loss: 0.9180 - acc: 0.5789\n",
      "Epoch 49/100\n",
      " - 0s - loss: 0.9175 - acc: 0.5774\n",
      "Epoch 50/100\n",
      " - 0s - loss: 0.9170 - acc: 0.5796\n",
      "Epoch 51/100\n",
      " - 0s - loss: 0.9166 - acc: 0.5789\n",
      "Epoch 52/100\n",
      " - 0s - loss: 0.9162 - acc: 0.5766\n",
      "Epoch 53/100\n",
      " - 0s - loss: 0.9157 - acc: 0.5758\n",
      "Epoch 54/100\n",
      " - 0s - loss: 0.9153 - acc: 0.5751\n",
      "Epoch 55/100\n",
      " - 0s - loss: 0.9149 - acc: 0.5766\n",
      "Epoch 56/100\n",
      " - 0s - loss: 0.9146 - acc: 0.5766\n",
      "Epoch 57/100\n",
      " - 0s - loss: 0.9142 - acc: 0.5774\n",
      "Epoch 58/100\n",
      " - 0s - loss: 0.9139 - acc: 0.5789\n",
      "Epoch 59/100\n",
      " - 0s - loss: 0.9135 - acc: 0.5751\n",
      "Epoch 60/100\n",
      " - 0s - loss: 0.9132 - acc: 0.5758\n",
      "Epoch 61/100\n",
      " - 0s - loss: 0.9128 - acc: 0.5766\n",
      "Epoch 62/100\n",
      " - 0s - loss: 0.9124 - acc: 0.5758\n",
      "Epoch 63/100\n",
      " - 0s - loss: 0.9122 - acc: 0.5758\n",
      "Epoch 64/100\n",
      " - 0s - loss: 0.9118 - acc: 0.5743\n",
      "Epoch 65/100\n",
      " - 0s - loss: 0.9115 - acc: 0.5736\n",
      "Epoch 66/100\n",
      " - 0s - loss: 0.9111 - acc: 0.5743\n",
      "Epoch 67/100\n",
      " - 0s - loss: 0.9109 - acc: 0.5736\n",
      "Epoch 68/100\n",
      " - 0s - loss: 0.9106 - acc: 0.5774\n",
      "Epoch 69/100\n",
      " - 0s - loss: 0.9104 - acc: 0.5751\n",
      "Epoch 70/100\n",
      " - 0s - loss: 0.9101 - acc: 0.5721\n",
      "Epoch 71/100\n",
      " - 0s - loss: 0.9098 - acc: 0.5728\n",
      "Epoch 72/100\n",
      " - 0s - loss: 0.9096 - acc: 0.5751\n",
      "Epoch 73/100\n",
      " - 0s - loss: 0.9092 - acc: 0.5766\n",
      "Epoch 74/100\n",
      " - 0s - loss: 0.9091 - acc: 0.5751\n",
      "Epoch 75/100\n",
      " - 0s - loss: 0.9087 - acc: 0.5736\n",
      "Epoch 76/100\n",
      " - 0s - loss: 0.9085 - acc: 0.5743\n",
      "Epoch 77/100\n",
      " - 0s - loss: 0.9082 - acc: 0.5728\n",
      "Epoch 78/100\n",
      " - 0s - loss: 0.9081 - acc: 0.5721\n",
      "Epoch 79/100\n",
      " - 0s - loss: 0.9078 - acc: 0.5706\n",
      "Epoch 80/100\n",
      " - 0s - loss: 0.9075 - acc: 0.5713\n",
      "Epoch 81/100\n",
      " - 0s - loss: 0.9073 - acc: 0.5774\n",
      "Epoch 82/100\n",
      " - 0s - loss: 0.9071 - acc: 0.5736\n",
      "Epoch 83/100\n",
      " - 0s - loss: 0.9069 - acc: 0.5728\n",
      "Epoch 84/100\n",
      " - 0s - loss: 0.9066 - acc: 0.5736\n",
      "Epoch 85/100\n",
      " - 0s - loss: 0.9064 - acc: 0.5736\n",
      "Epoch 86/100\n",
      " - 0s - loss: 0.9062 - acc: 0.5743\n",
      "Epoch 87/100\n",
      " - 0s - loss: 0.9060 - acc: 0.5721\n",
      "Epoch 88/100\n",
      " - 0s - loss: 0.9058 - acc: 0.5713\n",
      "Epoch 89/100\n",
      " - 0s - loss: 0.9056 - acc: 0.5728\n",
      "Epoch 90/100\n",
      " - 0s - loss: 0.9053 - acc: 0.5736\n",
      "Epoch 91/100\n",
      " - 0s - loss: 0.9052 - acc: 0.5736\n",
      "Epoch 92/100\n",
      " - 0s - loss: 0.9050 - acc: 0.5728\n",
      "Epoch 93/100\n",
      " - 0s - loss: 0.9048 - acc: 0.5751\n",
      "Epoch 94/100\n",
      " - 0s - loss: 0.9046 - acc: 0.5743\n",
      "Epoch 95/100\n",
      " - 0s - loss: 0.9045 - acc: 0.5743\n",
      "Epoch 96/100\n",
      " - 0s - loss: 0.9043 - acc: 0.5721\n",
      "Epoch 97/100\n",
      " - 0s - loss: 0.9041 - acc: 0.5736\n",
      "Epoch 98/100\n",
      " - 0s - loss: 0.9039 - acc: 0.5728\n",
      "Epoch 99/100\n",
      " - 0s - loss: 0.9037 - acc: 0.5736\n",
      "Epoch 100/100\n",
      " - 0s - loss: 0.9035 - acc: 0.5713\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x11a96b898>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: dense network on `train_x` alone (no categorical embeddings).\n",
    "# Uses the Keras 2 argument names (`units`, `epochs`) instead of the deprecated `output_dim` / `nb_epoch`.\n",
    "model = Sequential()\n",
    "model.add(Dense(units=12, input_dim=train_x.shape[1]))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(units=3))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, epochs=100, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_1 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_2 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 195\n",
      "Trainable params: 195\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Weight dimensions of the baseline model (kernel, then bias, for each `Dense` layer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12, 12)\n",
      "(12,)\n",
      "(12, 3)\n",
      "(3,)\n"
     ]
    }
   ],
   "source": [
    "for w in model.get_weights():\n",
    "    print(w.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "148/148 [==============================] - 0s 160us/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.8492063283920288, 0.5810810923576355]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_x, test_y, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.36988726, 0.20566124, 0.4244515 ],\n",
       "       [0.5748099 , 0.21531923, 0.20987089],\n",
       "       [0.20616835, 0.26551756, 0.52831405],\n",
       "       [0.49942362, 0.24646108, 0.25411525],\n",
       "       [0.6958955 , 0.14855078, 0.15555373],\n",
       "       [0.17576644, 0.5469389 , 0.27729467],\n",
       "       [0.07824349, 0.4926772 , 0.4290793 ],\n",
       "       [0.75813717, 0.14813529, 0.09372754],\n",
       "       [0.7503536 , 0.03009113, 0.21955532],\n",
       "       [0.57615584, 0.14076254, 0.28308156]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(test_x[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 3, ..., 3, 1, 3])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "liv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-split so the integer-coded `liv` / `edu` arrays stay separate for the embedding inputs.\n",
    "# NOTE: this rebinds train_x/test_x/train_y/test_y used by the baseline model above; train_x now comes from `x`.\n",
    "(train_x, test_x,\n",
    " train_liv, test_liv,\n",
    " train_edu, test_edu,\n",
    " train_y, test_y) = train_test_split(x, liv, edu, y, test_size=0.1, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:16: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n",
      "  app.launch_new_instance()\n",
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:18: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=12)`\n",
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:20: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n"
     ]
    }
   ],
   "source": [
    "# Embedding branch for the integer-coded `liv` category: one learned 4-d vector per code.\n",
    "encoder_liv = Sequential()\n",
    "encoder_liv.add(Embedding(liv_cats, 4, input_length=1))\n",
    "encoder_liv.add(Flatten())\n",
    "\n",
    "# Embedding branch for the integer-coded `edu` category: one learned 4-d vector per code.\n",
    "encoder_edu = Sequential()\n",
    "encoder_edu.add(Embedding(edu_cats, 4, input_length=1))\n",
    "encoder_edu.add(Flatten())\n",
    "\n",
    "# Dense branch projecting the feature matrix `x` to 4 units.\n",
    "dense_x = Sequential()\n",
    "dense_x.add(Dense(4, input_dim=x.shape[1]))\n",
    "\n",
    "# Concatenate the three 4-unit branches (12 units total), then classify into 3 classes.\n",
    "# NOTE: `Merge` is deprecated in Keras 2; replacing it with `concatenate` requires the functional API.\n",
    "model = Sequential()\n",
    "model.add(Merge([encoder_liv, encoder_edu, dense_x], mode='concat'))\n",
    "model.add(Dense(units=12))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(units=3))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/keras/models.py:942: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      " - 1s - loss: 1.0389 - acc: 0.4445\n",
      "Epoch 2/100\n",
      " - 0s - loss: 1.0056 - acc: 0.4860\n",
      "Epoch 3/100\n",
      " - 0s - loss: 0.9912 - acc: 0.5042\n",
      "Epoch 4/100\n",
      " - 0s - loss: 0.9809 - acc: 0.4981\n",
      "Epoch 5/100\n",
      " - 0s - loss: 0.9733 - acc: 0.5087\n",
      "Epoch 6/100\n",
      " - 0s - loss: 0.9662 - acc: 0.5087\n",
      "Epoch 7/100\n",
      " - 0s - loss: 0.9598 - acc: 0.5132\n",
      "Epoch 8/100\n",
      " - 0s - loss: 0.9536 - acc: 0.5215\n",
      "Epoch 9/100\n",
      " - 0s - loss: 0.9479 - acc: 0.5208\n",
      "Epoch 10/100\n",
      " - 0s - loss: 0.9431 - acc: 0.5268\n",
      "Epoch 11/100\n",
      " - 0s - loss: 0.9389 - acc: 0.5351\n",
      "Epoch 12/100\n",
      " - 0s - loss: 0.9349 - acc: 0.5343\n",
      "Epoch 13/100\n",
      " - 0s - loss: 0.9319 - acc: 0.5358\n",
      "Epoch 14/100\n",
      " - 0s - loss: 0.9295 - acc: 0.5404\n",
      "Epoch 15/100\n",
      " - 0s - loss: 0.9273 - acc: 0.5366\n",
      "Epoch 16/100\n",
      " - 0s - loss: 0.9254 - acc: 0.5396\n",
      "Epoch 17/100\n",
      " - 0s - loss: 0.9236 - acc: 0.5366\n",
      "Epoch 18/100\n",
      " - 0s - loss: 0.9220 - acc: 0.5351\n",
      "Epoch 19/100\n",
      " - 0s - loss: 0.9208 - acc: 0.5374\n",
      "Epoch 20/100\n",
      " - 0s - loss: 0.9195 - acc: 0.5328\n",
      "Epoch 21/100\n",
      " - 0s - loss: 0.9185 - acc: 0.5389\n",
      "Epoch 22/100\n",
      " - 0s - loss: 0.9175 - acc: 0.5411\n",
      "Epoch 23/100\n",
      " - 0s - loss: 0.9165 - acc: 0.5419\n",
      "Epoch 24/100\n",
      " - 0s - loss: 0.9158 - acc: 0.5479\n",
      "Epoch 25/100\n",
      " - 0s - loss: 0.9151 - acc: 0.5419\n",
      "Epoch 26/100\n",
      " - 0s - loss: 0.9143 - acc: 0.5472\n",
      "Epoch 27/100\n",
      " - 0s - loss: 0.9137 - acc: 0.5479\n",
      "Epoch 28/100\n",
      " - 0s - loss: 0.9131 - acc: 0.5479\n",
      "Epoch 29/100\n",
      " - 0s - loss: 0.9126 - acc: 0.5472\n",
      "Epoch 30/100\n",
      " - 0s - loss: 0.9121 - acc: 0.5494\n",
      "Epoch 31/100\n",
      " - 0s - loss: 0.9115 - acc: 0.5494\n",
      "Epoch 32/100\n",
      " - 0s - loss: 0.9110 - acc: 0.5472\n",
      "Epoch 33/100\n",
      " - 0s - loss: 0.9106 - acc: 0.5472\n",
      "Epoch 34/100\n",
      " - 0s - loss: 0.9101 - acc: 0.5509\n",
      "Epoch 35/100\n",
      " - 0s - loss: 0.9097 - acc: 0.5509\n",
      "Epoch 36/100\n",
      " - 0s - loss: 0.9093 - acc: 0.5517\n",
      "Epoch 37/100\n",
      " - 0s - loss: 0.9089 - acc: 0.5502\n",
      "Epoch 38/100\n",
      " - 0s - loss: 0.9085 - acc: 0.5487\n",
      "Epoch 39/100\n",
      " - 0s - loss: 0.9080 - acc: 0.5487\n",
      "Epoch 40/100\n",
      " - 0s - loss: 0.9078 - acc: 0.5487\n",
      "Epoch 41/100\n",
      " - 0s - loss: 0.9074 - acc: 0.5487\n",
      "Epoch 42/100\n",
      " - 0s - loss: 0.9070 - acc: 0.5509\n",
      "Epoch 43/100\n",
      " - 0s - loss: 0.9067 - acc: 0.5509\n",
      "Epoch 44/100\n",
      " - 0s - loss: 0.9063 - acc: 0.5525\n",
      "Epoch 45/100\n",
      " - 0s - loss: 0.9061 - acc: 0.5502\n",
      "Epoch 46/100\n",
      " - 0s - loss: 0.9058 - acc: 0.5525\n",
      "Epoch 47/100\n",
      " - 0s - loss: 0.9055 - acc: 0.5509\n",
      "Epoch 48/100\n",
      " - 0s - loss: 0.9053 - acc: 0.5525\n",
      "Epoch 49/100\n",
      " - 0s - loss: 0.9050 - acc: 0.5532\n",
      "Epoch 50/100\n",
      " - 0s - loss: 0.9047 - acc: 0.5517\n",
      "Epoch 51/100\n",
      " - 0s - loss: 0.9044 - acc: 0.5540\n",
      "Epoch 52/100\n",
      " - 0s - loss: 0.9041 - acc: 0.5562\n",
      "Epoch 53/100\n",
      " - 0s - loss: 0.9039 - acc: 0.5547\n",
      "Epoch 54/100\n",
      " - 0s - loss: 0.9036 - acc: 0.5532\n",
      "Epoch 55/100\n",
      " - 0s - loss: 0.9034 - acc: 0.5517\n",
      "Epoch 56/100\n",
      " - 0s - loss: 0.9032 - acc: 0.5532\n",
      "Epoch 57/100\n",
      " - 0s - loss: 0.9030 - acc: 0.5517\n",
      "Epoch 58/100\n",
      " - 0s - loss: 0.9028 - acc: 0.5509\n",
      "Epoch 59/100\n",
      " - 0s - loss: 0.9026 - acc: 0.5577\n",
      "Epoch 60/100\n",
      " - 0s - loss: 0.9024 - acc: 0.5547\n",
      "Epoch 61/100\n",
      " - 0s - loss: 0.9022 - acc: 0.5540\n",
      "Epoch 62/100\n",
      " - 0s - loss: 0.9019 - acc: 0.5562\n",
      "Epoch 63/100\n",
      " - 0s - loss: 0.9018 - acc: 0.5540\n",
      "Epoch 64/100\n",
      " - 0s - loss: 0.9016 - acc: 0.5555\n",
      "Epoch 65/100\n",
      " - 0s - loss: 0.9013 - acc: 0.5540\n",
      "Epoch 66/100\n",
      " - 0s - loss: 0.9012 - acc: 0.5555\n",
      "Epoch 67/100\n",
      " - 0s - loss: 0.9010 - acc: 0.5577\n",
      "Epoch 68/100\n",
      " - 0s - loss: 0.9008 - acc: 0.5585\n",
      "Epoch 69/100\n",
      " - 0s - loss: 0.9006 - acc: 0.5562\n",
      "Epoch 70/100\n",
      " - 0s - loss: 0.9005 - acc: 0.5562\n",
      "Epoch 71/100\n",
      " - 0s - loss: 0.9003 - acc: 0.5577\n",
      "Epoch 72/100\n",
      " - 0s - loss: 0.9002 - acc: 0.5600\n",
      "Epoch 73/100\n",
      " - 0s - loss: 0.9000 - acc: 0.5585\n",
      "Epoch 74/100\n",
      " - 0s - loss: 0.8997 - acc: 0.5585\n",
      "Epoch 75/100\n",
      " - 0s - loss: 0.8996 - acc: 0.5577\n",
      "Epoch 76/100\n",
      " - 0s - loss: 0.8995 - acc: 0.5585\n",
      "Epoch 77/100\n",
      " - 0s - loss: 0.8993 - acc: 0.5570\n",
      "Epoch 78/100\n",
      " - 0s - loss: 0.8992 - acc: 0.5570\n",
      "Epoch 79/100\n",
      " - 0s - loss: 0.8990 - acc: 0.5555\n",
      "Epoch 80/100\n",
      " - 0s - loss: 0.8987 - acc: 0.5577\n",
      "Epoch 81/100\n",
      " - 0s - loss: 0.8986 - acc: 0.5555\n",
      "Epoch 82/100\n",
      " - 0s - loss: 0.8985 - acc: 0.5615\n",
      "Epoch 83/100\n",
      " - 0s - loss: 0.8983 - acc: 0.5592\n",
      "Epoch 84/100\n",
      " - 0s - loss: 0.8981 - acc: 0.5600\n",
      "Epoch 85/100\n",
      " - 0s - loss: 0.8980 - acc: 0.5577\n",
      "Epoch 86/100\n",
      " - 0s - loss: 0.8979 - acc: 0.5600\n",
      "Epoch 87/100\n",
      " - 0s - loss: 0.8977 - acc: 0.5615\n",
      "Epoch 88/100\n",
      " - 0s - loss: 0.8975 - acc: 0.5592\n",
      "Epoch 89/100\n",
      " - 0s - loss: 0.8973 - acc: 0.5623\n",
      "Epoch 90/100\n",
      " - 0s - loss: 0.8973 - acc: 0.5608\n",
      "Epoch 91/100\n",
      " - 0s - loss: 0.8970 - acc: 0.5615\n",
      "Epoch 92/100\n",
      " - 0s - loss: 0.8971 - acc: 0.5577\n",
      "Epoch 93/100\n",
      " - 0s - loss: 0.8968 - acc: 0.5615\n",
      "Epoch 94/100\n",
      " - 0s - loss: 0.8967 - acc: 0.5623\n",
      "Epoch 95/100\n",
      " - 0s - loss: 0.8967 - acc: 0.5600\n",
      "Epoch 96/100\n",
      " - 0s - loss: 0.8964 - acc: 0.5615\n",
      "Epoch 97/100\n",
      " - 0s - loss: 0.8963 - acc: 0.5623\n",
      "Epoch 98/100\n",
      " - 0s - loss: 0.8961 - acc: 0.5615\n",
      "Epoch 99/100\n",
      " - 0s - loss: 0.8960 - acc: 0.5608\n",
      "Epoch 100/100\n",
      " - 0s - loss: 0.8959 - acc: 0.5608\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x11b5f6f60>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Input order must match the Merge branches: [liv codes, edu codes, x].\n",
    "# `[:, None]` adds the length-1 axis expected by Embedding(input_length=1).\n",
    "model.fit([train_liv[:, None], train_edu[:, None], train_x], train_y, epochs=100, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_3 (Dense)              (None, 4)                 20        \n",
      "=================================================================\n",
      "Total params: 20\n",
      "Trainable params: 20\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "dense_x.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 1, 4)              16        \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 4)                 0         \n",
      "=================================================================\n",
      "Total params: 16\n",
      "Trainable params: 16\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "encoder_liv.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "merge_1 (Merge)              (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_4 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 247\n",
      "Trainable params: 247\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 4)\n",
      "(4, 4)\n",
      "(4, 4)\n",
      "(12, 12)\n",
      "(12, 3)\n"
     ]
    }
   ],
   "source": [
    "# For this Merge-based Sequential model, get_weights() yields one (possibly empty) list per layer;\n",
    "# print the shape of each non-empty layer's first array (embedding matrix or Dense kernel).\n",
    "for layer_weights in model.get_weights():\n",
    "    if layer_weights:\n",
    "        print(layer_weights[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[array([[ 0.11730162,  0.2315124 ,  0.3354178 ,  0.3388376 ],\n",
       "         [ 0.07157213,  0.01113452,  0.16499558,  0.1336401 ],\n",
       "         [-0.12199867, -0.08454118, -0.0483714 ,  0.03888453],\n",
       "         [-0.08046792, -0.24010903, -0.06721216, -0.03922761]],\n",
       "        dtype=float32)],\n",
       " [],\n",
       " [array([[ 0.42741278, -0.0315637 , -0.09533015, -0.24039595],\n",
       "         [ 0.074214  ,  0.17136088,  0.08966726, -0.18540785],\n",
       "         [-0.03115179,  0.02272942,  0.16216515,  0.06415284],\n",
       "         [-0.18743801, -0.12739968,  0.17538098,  0.45354202]],\n",
       "        dtype=float32)],\n",
       " [],\n",
       " [array([[-0.30609292,  0.10283919, -1.0234545 , -0.0718132 ],\n",
       "         [-0.42188373, -1.0565518 ,  0.7799135 , -0.57389927],\n",
       "         [-0.31404954,  0.05534157,  0.3076612 , -0.18355557],\n",
       "         [ 0.68581975,  0.3141965 ,  0.49108952,  0.31939846]],\n",
       "        dtype=float32),\n",
       "  array([ 0.02104649, -0.0327523 ,  0.00743179,  0.08350972], dtype=float32)],\n",
       " [array([[-1.20145291e-01, -2.46437207e-01, -3.31275493e-01,\n",
       "          -2.20140591e-01,  3.22411954e-02, -1.85993418e-01,\n",
       "           1.50156682e-02,  5.86777270e-01,  3.74501377e-01,\n",
       "           6.00176156e-01, -2.33102456e-01,  1.97912589e-01],\n",
       "         [-5.84251247e-02,  2.84687608e-01,  2.09113523e-01,\n",
       "           6.88999832e-01,  1.59126058e-01, -6.32414639e-01,\n",
       "           5.85778318e-02,  2.36010000e-01,  4.87647623e-01,\n",
       "           7.10523129e-01, -2.98705012e-01, -6.73056424e-01],\n",
       "         [ 4.99374494e-02,  5.27223825e-01, -4.48906809e-01,\n",
       "           5.43242931e-01,  3.15158628e-02, -3.49278450e-01,\n",
       "           3.16208601e-01,  1.79898441e-01,  4.02369827e-01,\n",
       "           4.23717111e-01,  4.01757985e-01,  3.27482074e-01],\n",
       "         [ 1.68276043e-03,  2.24269614e-01, -1.55253738e-01,\n",
       "           1.37269288e-01,  1.37560502e-01,  6.06410444e-01,\n",
       "           4.23902065e-01,  2.36994013e-01,  1.34622231e-01,\n",
       "           1.31565452e-01,  4.10773098e-01, -4.18158025e-02],\n",
       "         [ 1.42986253e-01,  2.84708172e-01, -1.76201746e-01,\n",
       "          -7.45935440e-02,  1.47780567e-01, -4.74885345e-01,\n",
       "           5.34582853e-01,  6.59194946e-01,  6.03271089e-02,\n",
       "           2.65143346e-02,  6.03946209e-01, -7.07837999e-01],\n",
       "         [-5.67550713e-04,  5.80894232e-01,  4.77583051e-01,\n",
       "          -1.44562036e-01, -1.23427706e-02, -4.37089764e-02,\n",
       "           5.80135405e-01, -7.59563088e-01, -4.12246555e-01,\n",
       "          -1.29499286e-01, -3.76466185e-01, -1.08759917e-01],\n",
       "         [-6.76222324e-01, -3.34644705e-01, -4.17766720e-01,\n",
       "           2.54684478e-01,  5.39509356e-01,  6.20218337e-01,\n",
       "           3.36583495e-01,  1.41444914e-02, -2.79466182e-01,\n",
       "           2.39985809e-01,  4.79011893e-01,  7.59531260e-01],\n",
       "         [-3.45097154e-01, -4.46005911e-01,  1.11874945e-01,\n",
       "           6.93398491e-02,  2.85524666e-01,  7.27702677e-01,\n",
       "          -1.73978016e-01, -1.47991285e-01, -6.14667714e-01,\n",
       "          -3.82844925e-01, -7.19305396e-01,  1.07370567e+00],\n",
       "         [-1.70211956e-01,  4.82094437e-01,  5.11124790e-01,\n",
       "          -3.85133445e-01, -2.59445637e-01,  3.24229687e-01,\n",
       "          -2.08615094e-01, -5.63227236e-01,  2.33637407e-01,\n",
       "          -3.32046151e-01,  1.02926649e-01, -2.25125030e-01],\n",
       "         [-3.88037235e-01,  5.18728435e-01,  3.71372789e-01,\n",
       "           4.70427543e-01, -8.56025144e-02, -3.44712168e-01,\n",
       "          -4.05267291e-02, -2.89138436e-01,  4.28172737e-01,\n",
       "          -1.43907577e-01, -1.60402879e-02, -4.50862139e-01],\n",
       "         [ 9.97807682e-02, -5.70632160e-01, -3.15305412e-01,\n",
       "          -4.78042305e-01, -2.68817067e-01, -2.30937228e-02,\n",
       "           5.85884869e-01, -1.54629797e-01,  2.56999403e-01,\n",
       "          -3.21778595e-01, -6.45531178e-01, -3.42596471e-01],\n",
       "         [ 2.07257673e-01,  2.92181492e-01, -4.00053198e-03,\n",
       "           2.19944581e-01, -2.29526907e-01,  5.58538735e-03,\n",
       "           2.25658149e-01, -4.17933583e-01,  6.22005463e-02,\n",
       "          -8.05840790e-02,  2.19203681e-01, -1.41020834e-01]], dtype=float32),\n",
       "  array([-0.07727586, -0.00281815, -0.04617415, -0.00140562, -0.00134546,\n",
       "          0.16837443,  0.09992734, -0.08112685,  0.02305767,  0.04437137,\n",
       "          0.16232194,  0.13759789], dtype=float32)],\n",
       " [],\n",
       " [array([[ 0.5766323 ,  0.185857  , -0.0387473 ],\n",
       "         [ 0.6423362 , -0.1846978 ,  0.04110023],\n",
       "         [ 0.52129775,  0.07601085, -0.67859   ],\n",
       "         [ 0.26691735, -0.44697848,  0.01138065],\n",
       "         [-0.35870725, -0.10113026, -0.5956551 ],\n",
       "         [-0.7656568 , -0.08692111, -0.19708055],\n",
       "         [-0.03473576, -0.6235873 ,  0.7045059 ],\n",
       "         [ 0.64653885,  0.1463624 , -0.03452552],\n",
       "         [ 0.7291127 , -0.49087396, -0.07816873],\n",
       "         [ 0.65638536, -0.10689628, -0.22387852],\n",
       "         [ 0.13592255, -0.5635796 , -0.58162445],\n",
       "         [-0.37770072,  0.28712225,  0.04063358]], dtype=float32),\n",
       "  array([-0.06061557, -0.06954954,  0.11779507], dtype=float32)],\n",
       " []]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = model.get_weights()\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "148/148 [==============================] - 0s 285us/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.8468180298805237, 0.6081081032752991]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate([test_liv[:,None], test_edu[:,None], test_x],test_y, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.25250018, 0.31250244, 0.43499738],\n",
       "       [0.8849396 , 0.0538322 , 0.06122812],\n",
       "       [0.24710615, 0.16935235, 0.58354145],\n",
       "       [0.42179266, 0.33125114, 0.24695621],\n",
       "       [0.8215561 , 0.08922405, 0.08921982]], dtype=float32)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p = model.predict([test_liv[:,None], test_edu[:,None], test_x], batch_size=256)\n",
    "p[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "merge_1 (Merge)              (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_4 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 247\n",
      "Trainable params: 247\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n",
      "  after removing the cwd from sys.path.\n",
      "/usr/local/lib/python3.6/site-packages/keras/models.py:942: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "1325/1325 [==============================] - 0s 282us/step - loss: 0.6501 - acc: 0.6629\n",
      "Epoch 2/100\n",
      "1325/1325 [==============================] - 0s 60us/step - loss: 0.6353 - acc: 0.6667\n",
      "Epoch 3/100\n",
      "1325/1325 [==============================] - 0s 55us/step - loss: 0.6303 - acc: 0.6667\n",
      "Epoch 4/100\n",
      "1325/1325 [==============================] - 0s 58us/step - loss: 0.6274 - acc: 0.6667\n",
      "Epoch 5/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6254 - acc: 0.6667\n",
      "Epoch 6/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6240 - acc: 0.6664\n",
      "Epoch 7/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6228 - acc: 0.6664: 0s - loss: 0.6237 - acc: 0.666\n",
      "Epoch 8/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6219 - acc: 0.6664\n",
      "Epoch 9/100\n",
      "1325/1325 [==============================] - 0s 53us/step - loss: 0.6212 - acc: 0.6662\n",
      "Epoch 10/100\n",
      "1325/1325 [==============================] - 0s 51us/step - loss: 0.6207 - acc: 0.6659\n",
      "Epoch 11/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6202 - acc: 0.6662\n",
      "Epoch 12/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6198 - acc: 0.6662\n",
      "Epoch 13/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6194 - acc: 0.6659\n",
      "Epoch 14/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6191 - acc: 0.6659\n",
      "Epoch 15/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6187 - acc: 0.6662\n",
      "Epoch 16/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6184 - acc: 0.6662\n",
      "Epoch 17/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6181 - acc: 0.6662\n",
      "Epoch 18/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6178 - acc: 0.6662\n",
      "Epoch 19/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6175 - acc: 0.6659\n",
      "Epoch 20/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6171 - acc: 0.6662\n",
      "Epoch 21/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6168 - acc: 0.6662\n",
      "Epoch 22/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6165 - acc: 0.6664\n",
      "Epoch 23/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6162 - acc: 0.6659\n",
      "Epoch 24/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6159 - acc: 0.6662\n",
      "Epoch 25/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6156 - acc: 0.6662\n",
      "Epoch 26/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6153 - acc: 0.6662\n",
      "Epoch 27/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6150 - acc: 0.6662\n",
      "Epoch 28/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6147 - acc: 0.6664\n",
      "Epoch 29/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6145 - acc: 0.6662\n",
      "Epoch 30/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6142 - acc: 0.6662\n",
      "Epoch 31/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6139 - acc: 0.6662\n",
      "Epoch 32/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6136 - acc: 0.6662\n",
      "Epoch 33/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6133 - acc: 0.6662\n",
      "Epoch 34/100\n",
      "1325/1325 [==============================] - 0s 60us/step - loss: 0.6130 - acc: 0.6662\n",
      "Epoch 35/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6127 - acc: 0.6659\n",
      "Epoch 36/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6123 - acc: 0.6659\n",
      "Epoch 37/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6120 - acc: 0.6662\n",
      "Epoch 38/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6117 - acc: 0.6664\n",
      "Epoch 39/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6114 - acc: 0.6664\n",
      "Epoch 40/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6111 - acc: 0.6662\n",
      "Epoch 41/100\n",
      "1325/1325 [==============================] - 0s 51us/step - loss: 0.6108 - acc: 0.6664\n",
      "Epoch 42/100\n",
      "1325/1325 [==============================] - 0s 57us/step - loss: 0.6105 - acc: 0.6662\n",
      "Epoch 43/100\n",
      "1325/1325 [==============================] - 0s 57us/step - loss: 0.6102 - acc: 0.6662\n",
      "Epoch 44/100\n",
      "1325/1325 [==============================] - 0s 55us/step - loss: 0.6099 - acc: 0.6662\n",
      "Epoch 45/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6096 - acc: 0.6657\n",
      "Epoch 46/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6094 - acc: 0.6659\n",
      "Epoch 47/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6091 - acc: 0.6654\n",
      "Epoch 48/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6089 - acc: 0.6657\n",
      "Epoch 49/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6086 - acc: 0.6657\n",
      "Epoch 50/100\n",
      "1325/1325 [==============================] - 0s 53us/step - loss: 0.6083 - acc: 0.6662\n",
      "Epoch 51/100\n",
      "1325/1325 [==============================] - 0s 53us/step - loss: 0.6081 - acc: 0.6659\n",
      "Epoch 52/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6078 - acc: 0.6654\n",
      "Epoch 53/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6076 - acc: 0.6652\n",
      "Epoch 54/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6074 - acc: 0.6654\n",
      "Epoch 55/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6072 - acc: 0.6654\n",
      "Epoch 56/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6069 - acc: 0.6657\n",
      "Epoch 57/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6067 - acc: 0.6677\n",
      "Epoch 58/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6065 - acc: 0.6672\n",
      "Epoch 59/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6063 - acc: 0.6679\n",
      "Epoch 60/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6061 - acc: 0.6672\n",
      "Epoch 61/100\n",
      "1325/1325 [==============================] - 0s 52us/step - loss: 0.6059 - acc: 0.6697\n",
      "Epoch 62/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6057 - acc: 0.6699\n",
      "Epoch 63/100\n",
      "1325/1325 [==============================] - 0s 42us/step - loss: 0.6055 - acc: 0.6704\n",
      "Epoch 64/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6054 - acc: 0.6732\n",
      "Epoch 65/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6052 - acc: 0.6742\n",
      "Epoch 66/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6050 - acc: 0.6745\n",
      "Epoch 67/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6048 - acc: 0.6747\n",
      "Epoch 68/100\n",
      "1325/1325 [==============================] - 0s 53us/step - loss: 0.6046 - acc: 0.6742\n",
      "Epoch 69/100\n",
      "1325/1325 [==============================] - 0s 51us/step - loss: 0.6045 - acc: 0.6752\n",
      "Epoch 70/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6043 - acc: 0.6742\n",
      "Epoch 71/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6042 - acc: 0.6745\n",
      "Epoch 72/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6040 - acc: 0.6742\n",
      "Epoch 73/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6039 - acc: 0.6750\n",
      "Epoch 74/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6037 - acc: 0.6742\n",
      "Epoch 75/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6035 - acc: 0.6747\n",
      "Epoch 76/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6034 - acc: 0.6750\n",
      "Epoch 77/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6032 - acc: 0.6757\n",
      "Epoch 78/100\n",
      "1325/1325 [==============================] - 0s 46us/step - loss: 0.6031 - acc: 0.6757\n",
      "Epoch 79/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6030 - acc: 0.6755\n",
      "Epoch 80/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6028 - acc: 0.6752\n",
      "Epoch 81/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6027 - acc: 0.6755\n",
      "Epoch 82/100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6025 - acc: 0.6745\n",
      "Epoch 83/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6024 - acc: 0.6747\n",
      "Epoch 84/100\n",
      "1325/1325 [==============================] - 0s 49us/step - loss: 0.6023 - acc: 0.6747\n",
      "Epoch 85/100\n",
      "1325/1325 [==============================] - 0s 52us/step - loss: 0.6022 - acc: 0.6757\n",
      "Epoch 86/100\n",
      "1325/1325 [==============================] - 0s 51us/step - loss: 0.6020 - acc: 0.6767\n",
      "Epoch 87/100\n",
      "1325/1325 [==============================] - 0s 50us/step - loss: 0.6019 - acc: 0.6765\n",
      "Epoch 88/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6018 - acc: 0.6772\n",
      "Epoch 89/100\n",
      "1325/1325 [==============================] - 0s 47us/step - loss: 0.6017 - acc: 0.6775\n",
      "Epoch 90/100\n",
      "1325/1325 [==============================] - 0s 42us/step - loss: 0.6015 - acc: 0.6790\n",
      "Epoch 91/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6014 - acc: 0.6785\n",
      "Epoch 92/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6013 - acc: 0.6790\n",
      "Epoch 93/100\n",
      "1325/1325 [==============================] - 0s 44us/step - loss: 0.6012 - acc: 0.6795\n",
      "Epoch 94/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6011 - acc: 0.6795\n",
      "Epoch 95/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6010 - acc: 0.6808\n",
      "Epoch 96/100\n",
      "1325/1325 [==============================] - 0s 53us/step - loss: 0.6009 - acc: 0.6803\n",
      "Epoch 97/100\n",
      "1325/1325 [==============================] - 0s 51us/step - loss: 0.6008 - acc: 0.6815\n",
      "Epoch 98/100\n",
      "1325/1325 [==============================] - 0s 48us/step - loss: 0.6007 - acc: 0.6810\n",
      "Epoch 99/100\n",
      "1325/1325 [==============================] - 0s 45us/step - loss: 0.6006 - acc: 0.6805\n",
      "Epoch 100/100\n",
      "1325/1325 [==============================] - 0s 43us/step - loss: 0.6005 - acc: 0.6808\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x11b9c0ac8>"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The output layer is a 3-way softmax (mutually exclusive classes; train_y is\n",
    "# presumably one-hot with 3 columns), so the matching loss is\n",
    "# categorical_crossentropy. binary_crossentropy would score each output unit as\n",
    "# an independent binary label, and Keras would then report per-unit binary\n",
    "# accuracy, which stays near 2/3 for 3 classes even when few classes are right.\n",
    "model = Sequential()\n",
    "model.add(Dense(4, input_dim=train_x.shape[1]))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(3))  # Keras 2: `units` (positional) replaces deprecated `output_dim`\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, epochs=100)  # Keras 2: `epochs` replaces deprecated `nb_epoch`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out evaluation. Because the model was compiled with metrics=['accuracy'],\n",
    "# Keras returns [loss, accuracy] for test_x / test_y.\n",
    "model.evaluate(test_x, test_y, batch_size=256)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "latex_envs": {
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 0
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
